{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Trust but Verify - Inspection of Large Image Collections</h1>\n",
    "\n",
    "This notebook and accompanying [Python script](characterize_data.py) illustrate the use of SimpleITK as a tool for efficient data inspection on large image collections, as part of familiarizing oneself with the data and performing cleanup prior to its use in deep learning or any other supervised machine learning approach.\n",
    "\n",
    "The reasons for inspecting your data before using it include:\n",
    "1. Identification of corrupt images.\n",
    "2. Identification of erroneous images (label noise).\n",
    "3. Assessment of data quality and variability in terms of intensity range, image resolution, and pixel types.\n",
    "4. Reduction of workload, identifying redundant information content (e.g. a greyscale/single channel image masquerading as a color/three channel image - think x-ray in jpg file).\n",
    "\n",
    "\n",
    "We inspect our data in two ways:\n",
    "1. Summary of image characteristics found in a directory structure (generic and DICOM specific).\n",
    "2. Visual inspection of the image content.\n",
    "\n",
    "SimpleITK allows us to easily control for the image types we are interested in via the ImageIO. The currently supported/registered IO types are defined by the following strings:\n",
    "* 'BMPImageIO'\n",
    "* 'BioRadImageIO'\n",
    "* 'Bruker2dseqImageIO'\n",
    "* 'GDCMImageIO'\n",
    "* 'GE4ImageIO'\n",
    "* 'GE5ImageIO'\n",
    "* 'GiplImageIO'\n",
    "* 'HDF5ImageIO'\n",
    "* 'JPEGImageIO'\n",
    "* 'JPEG2000ImageIO'\n",
    "* 'LSMImageIO'\n",
    "* 'MINCImageIO'\n",
    "* 'MRCImageIO'\n",
    "* 'MetaImageIO'\n",
    "* 'NiftiImageIO'\n",
    "* 'NrrdImageIO'\n",
    "* 'PNGImageIO'\n",
    "* 'StimulateImageIO'\n",
    "* 'TIFFImageIO'\n",
    "* 'VTKImageIO'\n",
    "* '' - empty string denotes all image file formats.\n",
    "\n",
    "To see the set of ImageIO types supported by your version of SimpleITK, call [ImageFileReader::GetRegisteredImageIOs()](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ImageReaderBase.html#a4cbb7db3eb3796eee8d89a1aaf011511) or simply print an ImageFileReader object."
   ]
  },
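  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the query described above, the following lists the ImageIOs registered in your SimpleITK build and restricts a reader to one of them. The use of 'PNGImageIO' is only an illustration; whether it is registered depends on your installation.\n",
    "\n",
    "```python\n",
    "import SimpleITK as sitk\n",
    "\n",
    "reader = sitk.ImageFileReader()\n",
    "# Names of the ImageIOs registered in this SimpleITK build.\n",
    "registered_ios = reader.GetRegisteredImageIOs()\n",
    "print(registered_ios)\n",
    "\n",
    "# Restrict the reader to a single file format. Without this call (or with the\n",
    "# empty string) SimpleITK selects the ImageIO automatically for each file.\n",
    "if 'PNGImageIO' in registered_ios:\n",
    "    reader.SetImageIO('PNGImageIO')\n",
    "```\n",
    "\n",
    "Restricting the ImageIO is useful when a directory mixes formats and only one of them is of interest."
   ]
  },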
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import shutil\n",
    "import subprocess\n",
    "import multiprocessing\n",
    "from functools import partial\n",
    "\n",
    "# We use the multiprocess package instead of the official \n",
    "# multiprocessing as it currently has several issues as discussed\n",
    "# on the software carpentry page: https://hpc-carpentry.github.io/hpc-python/06-parallel/\n",
    "import multiprocess as mp\n",
    "import platform\n",
    "\n",
    "import hashlib\n",
    "import tempfile\n",
    "import pickle\n",
    "\n",
    "%matplotlib notebook\n",
    "import matplotlib.pyplot as plt\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "#utility method that either downloads data from the Girder repository or\n",
    "#if already downloaded returns the file name for reading from disk (cached data)\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "OUTPUT_DIR = 'Output'\n",
    "\n",
    "#Maximal number of parallel processes we run.\n",
    "MAX_PROCESSES = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "For convenience, we use the heterogeneous notebook data. This dataset includes both 2D and 3D images, color and greyscale and 3D images stored as a set of 2D slices (DICOM series). Before starting down this path, make sure you have downloaded all notebook data. Downloading the data is described in the [setup notebook](00_Setup.ipynb). \n",
    "\n",
    "\n",
    "Generally speaking, in the context of deep learning most datasets will be larger and more homogeneous. A nicely sized dataset that is \"just right\" in terms of download time, yet is large enough to illustrate the utility of data inspection, more than 7000 images, is the OpenI Indiana chest x-ray dataset (+100GB of DICOM images in a [single tgz file](https://openi.nlm.nih.gov/imgs/collections/NLMCXR_dcm.tgz)). \n",
    "\n",
    "The publication describing the dataset is: D. Demner-Fushman et. al., \"Preparing a collection of radiology examinations for distribution and retrieval\", J Am Med Inform Assoc., 23(2):304-310, 2016.\n",
    "\n",
    "Using the Indiana dataset, see how long it takes you to identify several images that should have not been included. See if you could identify them only using the textual based csv summary report file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_root_dir = '../Data'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Characterizing  image set\n",
    "\n",
    "To characterize the image set we have written a [Python script](characterize_data.py) that you should run from the command line. This script is very flexible and allows you to robustly characterize your image set. Try the various options and learn more about your data. You'd be surprised how many times the data isn't what you thought it is when only relying on visual inspection. The script allows you to inspect your data both on a file by file basis and as DICOM series where an image (volume) is stored in multiple files.\n",
    "\n",
    "File by file:\n",
    "```\n",
    "python characterize_data.py data output/generic_image_data_report.csv per_file \\\n",
    "--imageIO \"\" --external_applications ./dciodvfy --external_applications_headings \"DICOM Compliant\" \\\n",
    "--metadata_keys \"0008|0060\" \"0018|5101\" --metadata_keys_headings \"modality\" \"radiographic view\"\n",
    "```\n",
    "\n",
    "DICOM series:\n",
    "```\n",
    "python characterize_data.py data output/DICOM_image_data_report.csv per_series \\\n",
    "--metadata_keys \"0008|0060\" \"0018|5101\" --metadata_keys_headings \"modality\" \"radiographic view\"  \n",
    "```\n",
    "\n",
    "\n",
    "After characterizing the image set we turn to visual inspection. "
   ]
  },
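  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `--metadata_keys` values used in the command lines above are DICOM tags written in the 'group|element' form that SimpleITK uses for meta-data dictionary keys. As a minimal sketch, the same two tags can be read from a single file without loading its pixel data. The file name below is hypothetical and needs to be replaced with one of your own.\n",
    "\n",
    "```python\n",
    "import SimpleITK as sitk\n",
    "\n",
    "file_name = 'path/to/image.dcm'  # hypothetical file name, replace with your own\n",
    "\n",
    "reader = sitk.ImageFileReader()\n",
    "reader.SetFileName(file_name)\n",
    "reader.ReadImageInformation()  # reads the header only, not the pixel data\n",
    "\n",
    "# Same tags and headings as in the command lines above.\n",
    "tags = {'0008|0060': 'modality', '0018|5101': 'radiographic view'}\n",
    "for tag, heading in tags.items():\n",
    "    # A tag may be missing from a file, so check before accessing it.\n",
    "    value = reader.GetMetaData(tag) if reader.HasMetaDataKey(tag) else ''\n",
    "    print('{0}: {1}'.format(heading, value.strip()))\n",
    "```\n",
    "\n",
    "Reading only the image information keeps the check fast even for large volumes."
   ]
  },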
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visual inspection of  image set\n",
    "\n",
    "While the reports generated above provide us with many insights with respect to the image content they are not sufficient for identifying erroneous images included in the dataset. For example if we expect our data to contain only frontal, AP or PA, chest x-rays (CXRs) you often also find lateral CXRs in the mix. In theory this would be indicated via the radiographic view DICOM tag. Unfortunately, the tag does not always have the correct value, and when images are converted from DICOM to some other format this information is often lost.\n",
    "\n",
    "We therefore resort to visual inspection. There's nothing like a human to quickly scan an image collection and identify outliers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_image(img, projection_axis, thumbnail_size):\n",
    "    '''\n",
    "    Create a grayscale thumbnail image from the given image. If the image is 3D it is \n",
    "    projected to 2D using a Maximum Intensity Projection (MIP) approach. Color images\n",
    "    are converted to grayscale, and high dynamic range images are window leveled using\n",
    "    a robust approach.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    img (SimpleITK.Image): A 2D or 3D grayscale or sRGB image.\n",
    "    projection_axis(int in [0,2]): The axis along which we project 3D images.\n",
    "    thumbnail_size (list/tuple(int)): The 2D sizes of the thumbnail.  \n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    2D SimpleITK image with sitkUInt8 pixel type.\n",
    "    \n",
    "    '''\n",
    "    if img.GetDimension()==3 and img.GetSize()[2]==1: #2D image masquerading as 3D image\n",
    "        img = img[:,:,0]\n",
    "    elif img.GetDimension() == 3: #3D image projected along projection_axis direction\n",
    "        img = sitk.MaximumProjection(img,projection_axis)\n",
    "        slc = list(img.GetSize())\n",
    "        slc[projection_axis] = 0 \n",
    "        img = sitk.Extract(img,slc)\n",
    "    if img.GetNumberOfComponentsPerPixel() == 3: #sRGB image, convert to gray\n",
    "        # Convert sRGB image to gray scale and rescale results to [0,255]    \n",
    "        channels = [sitk.VectorIndexSelectionCast(img,i, sitk.sitkFloat32) for i in range(img.GetNumberOfComponentsPerPixel())]\n",
    "        #linear mapping\n",
    "        I = 1/255.0*(0.2126*channels[0] + 0.7152*channels[1] + 0.0722*channels[2])\n",
    "        #nonlinear gamma correction\n",
    "        I = I*sitk.Cast(I<=0.0031308,sitk.sitkFloat32)*12.92 + I**(1/2.4)*sitk.Cast(I>0.0031308,sitk.sitkFloat32)*1.055-0.055\n",
    "        img = sitk.Cast(sitk.RescaleIntensity(I), sitk.sitkUInt8)\n",
    "    else:\n",
    "        # To deal with high dynamic range images that also contain outlier intensities \n",
    "        # we use window-level intensity mapping and set the window:\n",
    "        # to [max(Q1 - w*IQR, min_intensity), min(Q3 + w*IQR, max_intensity)]\n",
    "        # IQR = Q3-Q1\n",
    "        # The bounds which should exclude outliers are defined by the parameter w,\n",
    "        # where 1.5 is a standard default value (same as used in box and\n",
    "        # whisker plots to define whisker lengths).\n",
    "        w=1.5\n",
    "        min_val,q1_val,q3_val,max_val = np.percentile(sitk.GetArrayViewFromImage(img).flatten(), [0,25,75,100])\n",
    "        min_max = [np.max([(1.0+w)*q1_val-w*q3_val, min_val]), np.min([(1.0+w)*q3_val-w*q1_val, max_val])]\n",
    "        wl_image = sitk.IntensityWindowing(img, windowMinimum=min_max[0], windowMaximum=min_max[1], outputMinimum=0.0, outputMaximum=255.0)\n",
    "        img = sitk.Cast(wl_image, sitk.sitkUInt8)\n",
    "    res = sitk.Resample(img, size=thumbnail_size,\n",
    "                        transform=sitk.Transform(),interpolator=sitk.sitkLinear,\n",
    "                        outputOrigin = img.GetOrigin(), \n",
    "                        outputSpacing = [(sz-1)*spc/(nsz-1) for nsz,sz,spc in zip(thumbnail_size, img.GetSize(), img.GetSpacing())],\n",
    "                        outputDirection = img.GetDirection(),\n",
    "                        defaultPixelValue=0,\n",
    "                        outputPixelType=img.GetPixelID())\n",
    "    res.SetOrigin([0,0])\n",
    "    res.SetSpacing([1,1])\n",
    "    res.SetDirection([1,0,0,1])\n",
    "    return res\n",
    "\n",
    "\n",
    "def visualize_single_file(file_name, imageIO, projection_axis, thumbnail_size):\n",
    "    image_file_name = ''\n",
    "    image = None\n",
    "    try:\n",
    "        reader = sitk.ImageFileReader()\n",
    "        reader.SetImageIO(imageIO)\n",
    "        reader.SetFileName(file_name)\n",
    "        img = reader.Execute()\n",
    "        image = process_image(img, projection_axis, thumbnail_size)\n",
    "        image_file_name = file_name\n",
    "    except:\n",
    "        pass\n",
    "    return (image_file_name, image)\n",
    "\n",
    "\n",
    "def visualize_files(root_dir, imageIO='', projection_axis = 2, thumbnail_size=[64,64], tile_size=[20,20]):\n",
    "    '''\n",
    "    This function traverses the directory structure reading all user selected images \n",
    "    (selction based on the image file format specified by the caller). All images are converted to 2D grayscale \n",
    "    in [0,255] as follows:\n",
    "    * Images with three channels are assumed to be in sRGB color space and converted to grayscale.\n",
    "    * Grayscale images are window-levelled using robust values for the window-level accomodating \n",
    "    * for outlying intensity values.\n",
    "    * 3D images are converted to 2D using maximum intensity projection along the user specified projection axis.    \n",
    "    Parameters\n",
    "    ----------\n",
    "    root_dir (str): Path to the root of the data directory. Traverse the directory structure\n",
    "                    and try to read every file as an image using the given imageIO.\n",
    "    imageIO (str): Name of image IO to use. To see the list of registered image IOs use the \n",
    "                   ImageFileReader::GetRegisteredImageIOs() or print an ImageFileReader.\n",
    "                   The empty string indicates to read all file formats supported by SimpleITK.\n",
    "    projection_axis (int in [0,2]): 3D images are converted to 2D using mean projection along the\n",
    "                                    specified axis.\n",
    "    thumbnail_size (2D tuple/list): The size of the 2D image tile used for visualization.\n",
    "    tile_size (2D tuple/list): Number of tiles to use in x and y.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    tuple(SimpleITK.Image, list): faux_volume comprised of tiles, file_name_list corrosponding\n",
    "                                  to the image tiles.\n",
    "                                  The SimpleITK image contains the meta-data 'thumbnail_size' and\n",
    "                                  'tile_size'.\n",
    "    '''\n",
    "    image_file_names = []\n",
    "    faux_volume = None\n",
    "    images = []\n",
    "\n",
    "    all_file_names = []\n",
    "    for dir_name, subdir_names, file_names in os.walk(root_dir):\n",
    "        all_file_names += [os.path.join(os.path.abspath(dir_name), fname) for fname in file_names]\n",
    "    if platform.system() == 'Windows':\n",
    "        res = map(partial(visualize_single_file,\n",
    "                          imageIO=imageIO, \n",
    "                          projection_axis=projection_axis, \n",
    "                          thumbnail_size=thumbnail_size), all_file_names)\n",
    "    else:\n",
    "        with mp.Pool(processes=MAX_PROCESSES) as pool:\n",
    "            res = pool.map(partial(visualize_single_file,\n",
    "                                   imageIO=imageIO, \n",
    "                                   projection_axis=projection_axis, \n",
    "                                   thumbnail_size=thumbnail_size), all_file_names)\n",
    "    res = [data for data in res if data[1] is not None]\n",
    "    if res:\n",
    "        image_file_names, images = zip(*res)\n",
    "        if image_file_names:\n",
    "            faux_volume = create_tile_volume(images, tile_size)\n",
    "            faux_volume.SetMetaData('thumbnail_size', ' '.join([str(v) for v in thumbnail_size]))\n",
    "            faux_volume.SetMetaData('tile_size', ' '.join([str(v) for v in tile_size]))\n",
    "    return (faux_volume, image_file_names)\n",
    "\n",
    "\n",
    "def create_tile_volume(images, tile_size):\n",
    "    '''\n",
    "    Create a faux-volume from a list of images. Each slice in the volume \n",
    "    is constructed from tile_size[0]*tile_size[1] images. The slices are \n",
    "    then joined to form the faux volume.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    images (list(SimpleITK.Image(2D, sitkUInt8))): image list that we tile.\n",
    "    tile_size (2D tuple/list): Number of tiles to use in x and y.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    SimpleITK.Image(3D, sitkUInt8): Volume comprised of tiled image slices. \n",
    "                                    Order of tiles matches the order of the input list.\n",
    "    '''\n",
    "    step_size = tile_size[0]*tile_size[1]\n",
    "    faux_volume = [sitk.Tile(images[i:i+step_size], tile_size, 0) for i in range(0, len(images), step_size)]\n",
    "    #if last tile image is smaller than others, add background content to match the size\n",
    "    if len(faux_volume)>1 and \\\n",
    "       (faux_volume[-1].GetHeight()!=faux_volume[-2].GetHeight() or faux_volume[-1].GetWidth()!=faux_volume[-2].GetWidth()):\n",
    "        img = sitk.Image(faux_volume[-2])*0\n",
    "        faux_volume[-1] = sitk.Paste(img, faux_volume[-1], faux_volume[-1].GetSize(), [0,0], [0,0])       \n",
    "    return sitk.JoinSeries(faux_volume)\n",
    "    \n",
    "\n",
    "def visualize_series(root_dir, projection_axis = 2, thumbnail_size=[64,64], tile_size=[20,20]):\n",
    "    '''\n",
    "    This function traverses the directory structure reading all DICOM series (a series can reside\n",
    "    in multiple directories). All images are converted to 2D grayscale in [0,255] as follows:\n",
    "    * Images with three channels are assumed to be in sRGB color space and converted to grayscale.\n",
    "    * Grayscale images are window-levelled using robust values for the window-level accomodating \n",
    "    * for outlying intensity values.\n",
    "    * 3D images are converted to 2D using maximum intensity projection along the user specified projection axis.    \n",
    "    Parameters\n",
    "    ----------\n",
    "    root_dir (str): Path to the root of the data directory. Traverse the directory structure\n",
    "                    and try to read every file as an image using the given imageIO.\n",
    "    projection_axis (int in [0,2]): 3D images are converted to 2D using mean projection along the\n",
    "                                    specified axis.\n",
    "    thumbnail_size (2D tuple/list): The size of the 2D image tile used for visualization.\n",
    "    tile_size (2D tuple/list): Number of tiles to use in x and y.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    tuple(SimpleITK.Image, list): faux_volume comprised of tiles, series_file_name_lists corrosponding\n",
    "                                  to the image tiles. The series_file_name_lists is a list of lists where\n",
    "                                  the sublists are DICOM series.\n",
    "                                  The SimpleITK image contains the meta-data 'thumbnail_size' and\n",
    "                                  'tile_size'.\n",
    "    '''    \n",
    "    #collect the file names of all series into a dictionary with the key being\n",
    "    #study:series.\n",
    "    all_series_files = {}\n",
    "    reader = sitk.ImageFileReader()\n",
    "    for dir_name, subdir_names, file_names in os.walk(root_dir):\n",
    "        sids = sitk.ImageSeriesReader_GetGDCMSeriesIDs(dir_name)\n",
    "        for sid in sids: # Using absolute file names so that the list is valid no matter where the script is run\n",
    "            file_names = [os.path.abspath(fname) for fname in sitk.ImageSeriesReader_GetGDCMSeriesFileNames(dir_name, sid)]\n",
    "            reader.SetFileName(file_names[0])\n",
    "            reader.ReadImageInformation()\n",
    "            study = reader.GetMetaData('0020|000d')\n",
    "            key = '{0}:{1}'.format(study,sid)\n",
    "            if key in all_series_files:\n",
    "                all_series_files[key].extend(file_names)\n",
    "            else:\n",
    "                all_series_files[key] = list(file_names)\n",
    "    images_and_files = [(process_series(series_data, projection_axis, thumbnail_size), series_data[1]) \n",
    "                         for series_data in all_series_files.items()]\n",
    "    images,files = zip(*images_and_files)\n",
    "    faux_volume = create_tile_volume(images, tile_size)\n",
    "    faux_volume.SetMetaData('thumbnail_size', ' '.join([str(v) for v in thumbnail_size]))\n",
    "    faux_volume.SetMetaData('tile_size', ' '.join([str(v) for v in tile_size]))\n",
    "    return (faux_volume, files)\n",
    "\n",
    "\n",
    "def process_series(series_data, projection_axis, thumbnail_size):\n",
    "    reader = sitk.ImageSeriesReader()\n",
    "    _,sid = series_data[0].split(':')\n",
    "    file_names = series_data[1]\n",
    "    # As the files comprising a series with multiple files can reside in \n",
    "    # separate directories and SimpleITK expects them to be in a single directory \n",
    "    # we use a tempdir and symbolic links to enable SimpleITK to read the series as\n",
    "    # a single image.\n",
    "    with tempfile.TemporaryDirectory() as tmpdirname:\n",
    "        if platform.system() == 'Windows':\n",
    "            for i, fname in enumerate(file_names):\n",
    "                shutil.copy(fname, \n",
    "                            os.path.join(tmpdirname,str(i)))\n",
    "        else:\n",
    "            for i, fname in enumerate(file_names):\n",
    "                os.symlink(fname,\n",
    "                           os.path.join(tmpdirname,str(i)))\n",
    "        reader.SetFileNames(sitk.ImageSeriesReader_GetGDCMSeriesFileNames(tmpdirname, sid))\n",
    "        img = reader.Execute()\n",
    "        return process_image(img, projection_axis, thumbnail_size)"
   ]
  },
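  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The robust window used by `process_image` above can be illustrated in isolation. The following is a minimal NumPy-only sketch of the same bounds, [max(Q1 - w*IQR, min_intensity), min(Q3 + w*IQR, max_intensity)]; `robust_window` is a hypothetical helper name introduced for illustration and the sample intensities are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def robust_window(values, w=1.5):\n",
    "    # Quartile based bounds, clipped to the actual intensity range.\n",
    "    min_val, q1, q3, max_val = np.percentile(values, [0, 25, 75, 100])\n",
    "    iqr = q3 - q1\n",
    "    return max(q1 - w*iqr, min_val), min(q3 + w*iqr, max_val)\n",
    "\n",
    "\n",
    "# Intensities 0..100 plus a single extreme outlier.\n",
    "values = np.append(np.arange(101, dtype=float), 10000.0)\n",
    "window_min, window_max = robust_window(values)\n",
    "print(window_min, window_max)\n",
    "```\n",
    "\n",
    "For this input the upper bound stays far below the outlier value of 10000, so a single extreme intensity does not compress the rest of the range into a few gray levels."
   ]
  },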
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class in the following cell `ImageSelection` provides a GUI for displaying and interacting with a tiled faux volume. The user can scroll through the faux volume \"slices\", zoom in, pan, and select images. When the user clicks on an image a user specified action is taken, `selection_func` is invoked with the file name(s) of the associated image. Two useful user functions are provided at the end of the code cell:\n",
    "* `show_image` - displays the original image at full resolution using an external viewer (both 2D and 3D).\n",
    "* `rm_image` - for the more confident user, delete the file(s) associated with the selected image (data cleanup).\n",
    "\n",
    "The recommended usage is with the `show_image` ensuring that the images you selected should truly be deleted and then deleting them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageSelection(object):\n",
    "    def __init__(self, tiled_faux_vol, image_files_list, selection_func=None, figure_size=(8,8)):\n",
    "        self.tiled_faux_vol = tiled_faux_vol\n",
    "        self.thumbnail_size = [int(v) for v in self.tiled_faux_vol.GetMetaData('thumbnail_size').split()]\n",
    "        self.tile_size = [int(v) for v in self.tiled_faux_vol.GetMetaData('tile_size').split()]        \n",
    "        self.npa = sitk.GetArrayViewFromImage(self.tiled_faux_vol)\n",
    "        self.point_indexes = []\n",
    "        self.selected_image_indexes = []\n",
    "        self.image_files_list = image_files_list\n",
    "        self.selection_func = selection_func\n",
    "        \n",
    "        ui = self.create_ui()\n",
    "        display(ui)\n",
    "\n",
    "        # Create a figure.\n",
    "        self.fig, self.axes = plt.subplots(1,1,figsize=figure_size)\n",
    "        # Connect the mouse button press to the canvas (__call__ method is the invoked callback).\n",
    "        self.fig.canvas.mpl_connect('button_press_event', self)\n",
    "\n",
    "        # Display the data and the controls, first time we display the image is outside the \"update_display\" method\n",
    "        # as that method relies on the previous zoom factor which doesn't exist yet.\n",
    "        self.axes.imshow(self.npa[self.slice_slider.value,:,:] if self.slice_slider else self.npa,\n",
    "                         cmap=plt.cm.Greys_r)\n",
    "        self.fig.tight_layout()\n",
    "        self.update_display()\n",
    "\n",
    "\n",
    "    def create_ui(self):\n",
    "        # Create the active GUI components. Height and width are specified in 'em' units. This is\n",
    "        # a HTML size specification, size relative to current font size.\n",
    "        self.viewing_checkbox = widgets.RadioButtons(description= 'Interaction mode:',\n",
    "                                                     options= ['edit', 'view'],\n",
    "                                                     value = 'edit')\n",
    "\n",
    "        self.clearlast_button = widgets.Button(description= 'Clear Last',\n",
    "                                               width= '7em',\n",
    "                                               height= '3em')\n",
    "        self.clearlast_button.on_click(self.clear_last)\n",
    "\n",
    "        self.clearall_button = widgets.Button(description= 'Clear All',\n",
    "                                              width= '7em',\n",
    "                                              height= '3em')\n",
    "        self.clearall_button.on_click(self.clear_all)\n",
    "\n",
    "        # Slider is only created if a 3D image, otherwise no need.\n",
    "        self.slice_slider = None\n",
    "        if self.npa.ndim == 3:\n",
    "            self.slice_slider = widgets.IntSlider(description='image z slice:',\n",
    "                                                  min=0,\n",
    "                                                  max=self.npa.shape[0]-1,\n",
    "                                                  step=1,\n",
    "                                                  value = int((self.npa.shape[0]-1)/2),\n",
    "                                                  width='20em')\n",
    "            self.slice_slider.observe(self.on_slice_slider_value_change, names='value')\n",
    "            bx0 = widgets.Box(padding=7, children=[self.slice_slider])\n",
    "\n",
    "        # Layout of GUI components. This is pure ugliness because we are not using a GUI toolkit. Layout is done\n",
    "        # using the box widget and padding so that the visible GUI components are spaced nicely.\n",
    "        bx1 = widgets.Box(padding=7, children = [self.viewing_checkbox])\n",
    "        bx2 = widgets.Box(padding = 15, children = [self.clearlast_button])\n",
    "        bx3 = widgets.Box(padding = 15, children = [self.clearall_button])\n",
    "        return widgets.HBox(children=[widgets.HBox(children=[bx1, bx2, bx3]),bx0]) if self.slice_slider else widgets.HBox(children=[widgets.HBox(children=[bx1, bx2, bx3])])\n",
    "\n",
    "\n",
    "    def on_slice_slider_value_change(self, change):\n",
    "        self.update_display()\n",
    "\n",
    "    def update_display(self):\n",
    "        # We want to keep the zoom factor which was set prior to display, so we log it before\n",
    "        # clearing the axes.\n",
    "        xlim = self.axes.get_xlim()\n",
    "        ylim = self.axes.get_ylim()\n",
    "\n",
    "        # Draw the image and localized points.\n",
    "        self.axes.clear()\n",
    "        self.axes.imshow(self.npa[self.slice_slider.value,:,:] if self.slice_slider else self.npa,\n",
    "                         cmap=plt.cm.Greys_r)\n",
    "        for i, pnt in enumerate(self.point_indexes):\n",
    "            if(self.slice_slider and int(pnt[2] + 0.5) == self.slice_slider.value) or not self.slice_slider:\n",
    "                self.axes.scatter(pnt[0], pnt[1], s=90, marker='+', color='yellow')\n",
    "                # Get point in pixels.\n",
    "        self.axes.set_title('selected {0} images'.format(len(self.point_indexes)))\n",
    "        self.axes.set_axis_off()\n",
    "\n",
    "\n",
    "        # Set the zoom factor back to what it was before we cleared the axes, and rendered our data.\n",
    "        self.axes.set_xlim(xlim)\n",
    "        self.axes.set_ylim(ylim)\n",
    "\n",
    "        self.fig.canvas.draw_idle()\n",
    "\n",
    "    def clear_all(self, button):\n",
    "        del self.point_indexes[:]\n",
    "        del self.selected_image_indexes[:]\n",
    "        self.update_display()\n",
    "\n",
    "    def clear_last(self, button):\n",
    "        if self.point_indexes:\n",
    "            self.point_indexes.pop()\n",
    "            self.selected_image_indexes.pop()\n",
    "            self.update_display()\n",
    "\n",
    "    def get_selected_images(self):\n",
    "        return [self.image_files_list[index] for index in self.selected_image_indexes]\n",
    "\n",
    "    def __call__(self, event):\n",
    "        if self.viewing_checkbox.value == 'edit':\n",
    "            if event.inaxes==self.axes:\n",
    "                x = int(round(event.xdata))\n",
    "                y = int(round(event.ydata))\n",
    "                z = self.slice_slider.value\n",
    "                image_index = z*self.tile_size[0]*self.tile_size[1] + \\\n",
    "                              int(y/self.thumbnail_size[1])*self.tile_size[0] + \\\n",
    "                              int(x/self.thumbnail_size[0])\n",
    "                if image_index<len(self.image_files_list):\n",
    "                    #If new selection add it, otherwise just redisplay the image by calling Show.\n",
    "                    if image_index not in self.selected_image_indexes:\n",
    "                        self.point_indexes.append((event.xdata, event.ydata, self.slice_slider.value) if self.slice_slider else (event.xdata, event.ydata))\n",
    "                        self.selected_image_indexes.append(image_index)\n",
    "                        self.update_display()\n",
    "                    if self.selection_func:\n",
    "                        self.selection_func(self.image_files_list[image_index])\n",
    "                        \n",
    "def show_image(image_file_name):\n",
    "    if isinstance(image_file_name, str):\n",
    "        img = sitk.ReadImage(image_file_name)\n",
    "    else:\n",
    "        # As the files comprising a DICOM series with multiple files can reside in \n",
    "        # separate directories and SimpleITK expects them to be in a single directory \n",
    "        # we use a tempdir and symbolic links to enable SimpleITK to read the series as\n",
    "        # a single image.\n",
    "        with tempfile.TemporaryDirectory() as tmpdirname:\n",
    "            if platform.system() == 'Windows':\n",
    "                for i, fname in enumerate(image_file_name):\n",
    "                    shutil.copy(os.path.abspath(fname),\n",
    "                               os.path.join(tmpdirname,str(i)))\n",
    "            else:\n",
    "                for i, fname in enumerate(image_file_name):\n",
    "                    os.symlink(os.path.abspath(fname),\n",
    "                               os.path.join(tmpdirname,str(i)))                \n",
    "            img = sitk.ReadImage(sitk.ImageSeriesReader_GetGDCMSeriesFileNames(tmpdirname))\n",
    "    sitk.Show(img)\n",
    "\n",
    "    \n",
    "def rm_image(image_file_name):\n",
    "    try: #if file doesn't exist an exception is thrown.\n",
    "        if isinstance(image_file_name, basestring):\n",
    "            os.remove(image_file_name)\n",
    "        else:\n",
    "            for f in image_file_name:\n",
    "                os.remove(f)\n",
    "    except:\n",
    "        pass"
   ]
  },
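  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mapping from a mouse click to an image index, as computed in `ImageSelection.__call__` above, can be sketched on its own. `tile_index` is a hypothetical helper name introduced for illustration; the arithmetic mirrors the class, with tiles laid out row by row within each faux volume slice.\n",
    "\n",
    "```python\n",
    "def tile_index(x, y, z, thumbnail_size, tile_size):\n",
    "    # Column and row of the tile containing pixel (x, y) in the slice.\n",
    "    column = int(x / thumbnail_size[0])\n",
    "    row = int(y / thumbnail_size[1])\n",
    "    # Each slice holds tile_size[0]*tile_size[1] thumbnails.\n",
    "    return z*tile_size[0]*tile_size[1] + row*tile_size[0] + column\n",
    "\n",
    "\n",
    "# Click at pixel (130, 70) on the second slice of a 30x20 tiling of 64x64 thumbnails.\n",
    "print(tile_index(x=130, y=70, z=1, thumbnail_size=[64, 64], tile_size=[30, 20]))\n",
    "```\n",
    "\n",
    "Here the click falls in column 2 and row 1 of slice 1, giving index 1\\*600 + 1\\*30 + 2 = 632, which is the position of the corresponding entry in the file name list."
   ]
  },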
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next run our visualization based analysis. For large datasets this may take some time, so we save the results via pickling so that we can look at them at a later point in time.\n",
    "\n",
    "We start with the file based approach. Notice that with this approach the display is dominated by slices from several DICOM series. After selecting our images of interest we print the associated file names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "faux_image_volume_file_name = os.path.join(OUTPUT_DIR, 'faux_image_volume.pkl')\n",
    "faux_image_file_list_name = os.path.join(OUTPUT_DIR, 'faux_image_file_list.pkl')\n",
    "faux_volume_image_files, image_file_list = visualize_files(data_root_dir, imageIO = '', projection_axis=2, thumbnail_size=[64,64], tile_size=[30,20])\n",
    "with open(faux_image_volume_file_name, 'wb') as fp:\n",
    "    pickle.dump(faux_volume_image_files, fp)\n",
    "with open(faux_image_file_list_name, 'wb') as fp:\n",
    "    pickle.dump(image_file_list, fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "with open(faux_image_volume_file_name, 'rb') as fp:\n",
    "    faux_volume_image_files = pickle.load(fp)\n",
    "with open(faux_image_file_list_name, 'rb') as fp:\n",
    "    image_file_list = pickle.load(fp)\n",
    "\n",
    "image_selection_gui = ImageSelection(faux_volume_image_files, image_file_list, \n",
    "                                     figure_size = (10,8), selection_func=show_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_files = image_selection_gui.get_selected_images()\n",
    "print(selected_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we look at our data using the DICOM series based approach.\n",
    "\n",
    "After selecting our images of interest we print the associated files. Notice that for the series based approach for some images there is a single file association and for some multiple files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "faux_series_volume_file_name = os.path.join(OUTPUT_DIR, 'faux_series_volume.pkl')\n",
    "faux_series_file_list_name = os.path.join(OUTPUT_DIR, 'faux_series_file_list.pkl')\n",
    "faux_volume_image_files, image_file_list = visualize_series(data_root_dir, projection_axis=2, thumbnail_size=[64,64], tile_size=[30,20])\n",
    "with open(faux_series_volume_file_name, 'wb') as fp:\n",
    "    pickle.dump(faux_volume_image_files, fp)\n",
    "with open(faux_series_file_list_name, 'wb') as fp:\n",
    "    pickle.dump(image_file_list, fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "with open(faux_series_volume_file_name, 'rb') as fp:\n",
    "    faux_volume_image_files = pickle.load(fp)\n",
    "with open(faux_series_file_list_name, 'rb') as fp:\n",
    "    image_file_list = pickle.load(fp)\n",
    "\n",
    "image_selection_gui2 = ImageSelection(faux_volume_image_files, image_file_list, \n",
    "                                     figure_size = (5,4), selection_func=show_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_files = image_selection_gui2.get_selected_images()\n",
    "print(selected_files)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
